{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "dog-breed-classification.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hnEav91sjMfr"
      },
      "source": [
        "This notebook trains a dog breed classification model that distinguishes 133 dog breeds, using a pretrained ResNet-50 as the feature extractor.\n",
        "\n",
        "The dataset is the dog breed dataset from Udacity; the training code is borrowed from the [PyTorch transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1m9nsjHe4s04"
      },
      "source": [
        "# -nc (no-clobber) skips the download when dogImages.zip is already present,\n",
        "# so re-running the cell does not fetch a second copy.\n",
        "!wget -nc https://s3-us-west-1.amazonaws.com/udacity-aind/dog-project/dogImages.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dbTFCkZr5AeC"
      },
      "source": [
        "# -q keeps the per-file listing out of the cell output; -n never overwrites\n",
        "# files that were already extracted, so re-running the cell cannot stop on a prompt.\n",
        "!unzip -q -n dogImages.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m9m5OV2ckm_s"
      },
      "source": [
        "Running a sanity check to verify the folders and the number of images in them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "89sRowH_QcFc"
      },
      "source": [
        "# List the files stored in one class folder of the validation split.\n",
        "!ls dogImages/valid/005.Alaskan_malamute"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w7Si8tpTlqXr"
      },
      "source": [
        "Count the entries in the validation folder; there should be one sub-folder per breed (133)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "68_4WoR0LpTB"
      },
      "source": [
        "# Count the entries of the validation folder without changing the kernel's\n",
        "# working directory, so the cell can be re-run safely.\n",
        "!ls dogImages/valid | wc -l"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G6AmfqZmRMAv"
      },
      "source": [
        "from PIL import ImageFile\n",
        "# Ask PIL to load image files that are truncated rather than failing on them\n",
        "# (presumably some dataset images are incomplete -- TODO confirm).\n",
        "ImageFile.LOAD_TRUNCATED_IMAGES = True"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OS-xSMdYM0W-"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.utils.data\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "import torchvision.models as models\n",
        "from torchvision import transforms\n",
        "from PIL import Image\n",
        "import matplotlib.pyplot as plt\n",
        "from torch.optim import lr_scheduler\n",
        "from torchvision import datasets, models, transforms\n",
        "import os\n",
        "import time\n",
        "import copy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "21FLnitamRPJ"
      },
      "source": [
        "Setting the dataset, data loader and the image transforms for different sets."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qUQ_tNYyMupx"
      },
      "source": [
        "# Per-channel mean / std applied after ToTensor(); shared by every split.\n",
        "normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "\n",
        "# Deterministic preprocessing used by both evaluation splits.\n",
        "eval_transform = transforms.Compose([\n",
        "    transforms.Resize(256),\n",
        "    transforms.CenterCrop(224),\n",
        "    transforms.ToTensor(),\n",
        "    normalize,\n",
        "])\n",
        "\n",
        "data_transforms = {\n",
        "    # Random crop / flip augmentation is applied to the training split only.\n",
        "    'train': transforms.Compose([\n",
        "        transforms.RandomResizedCrop(224),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        normalize,\n",
        "    ]),\n",
        "    'valid': eval_transform,\n",
        "    'test': eval_transform,\n",
        "}\n",
        "data_dir = '/content/dogImages/'\n",
        "splits = ['train', 'valid', 'test']\n",
        "\n",
        "# One ImageFolder dataset per split, read from data_dir/<split>.\n",
        "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n",
        "                                          data_transforms[x])\n",
        "                  for x in splits}\n",
        "\n",
        "# Only the training split is shuffled; the evaluation splits are read in\n",
        "# dataset order so that repeated evaluations visit the samples identically.\n",
        "dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,\n",
        "                                             shuffle=(x == 'train'), num_workers=2)\n",
        "              for x in splits}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "txbJ7AR5mmsY"
      },
      "source": [
        "Load the pretrained ResNet-50 from Torch Hub, freeze every parameter whose name does not contain `bn`, and replace the last layer so that it outputs the number of dog breeds in the dataset, 133."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FBaiDmo5NtBk"
      },
      "source": [
        "pretrained_model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)\n",
        "\n",
        "# Freeze every loaded parameter except those whose name contains \"bn\";\n",
        "# the latter keep whatever requires_grad value they were loaded with.\n",
        "for name, param in pretrained_model.named_parameters():\n",
        "    if \"bn\" in name:\n",
        "        continue\n",
        "    param.requires_grad_(False)\n",
        "\n",
        "# Swap the final layer for a new 133-output linear layer. It is created after\n",
        "# the loop above, so the loop does not touch its parameters.\n",
        "num_ftrs = pretrained_model.fc.in_features\n",
        "pretrained_model.fc = nn.Linear(in_features=num_ftrs, out_features=133)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LItRc9k0m5re"
      },
      "source": [
        "Setting the optimizer and learning rate scheduler."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9U9_D_3xPQcB"
      },
      "source": [
        "# Loss used for both the training and the validation phase.\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.Adam(params=pretrained_model.parameters(), lr=0.001)\n",
        "# Multiply the learning rate by 0.1 after every 7 scheduler steps.\n",
        "exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KTsPsCounHsG"
      },
      "source": [
        "Define the training loop."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "l-QYory0PXmk"
      },
      "source": [
        "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
        "    \"\"\"Train ``model`` and return it holding its best validation weights.\n",
        "\n",
        "    Each epoch runs a 'train' phase and then a 'valid' phase over the\n",
        "    module-level ``dataloaders``, moving every batch to the module-level\n",
        "    ``device``. The weights with the highest validation accuracy are kept\n",
        "    and loaded back into the model before it is returned.\n",
        "\n",
        "    Args:\n",
        "        model: network to train; it is updated in place.\n",
        "        criterion: loss callable invoked as ``criterion(outputs, labels)``.\n",
        "        optimizer: optimizer stepped once per training batch.\n",
        "        scheduler: learning-rate scheduler stepped once per epoch, after\n",
        "            the training phase.\n",
        "        num_epochs: number of epochs to run.\n",
        "\n",
        "    Returns:\n",
        "        The same ``model`` object with the best validation weights loaded.\n",
        "    \"\"\"\n",
        "    since = time.time()\n",
        "\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
        "        print('-' * 10)\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in ['train', 'valid']:\n",
        "            is_training = phase == 'train'\n",
        "            # train(True) selects training mode, train(False) evaluation mode.\n",
        "            model.train(is_training)\n",
        "\n",
        "            # Read the split size from the loader itself, so the function does\n",
        "            # not rely on a size table built in another cell.\n",
        "            num_samples = len(dataloaders[phase].dataset)\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "                # zero the parameter gradients\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # Autograd history is recorded only in the training phase.\n",
        "                with torch.set_grad_enabled(is_training):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "\n",
        "                    # backward + optimize only if in training phase\n",
        "                    if is_training:\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # Weight the batch loss by the batch size so the epoch value is a\n",
        "                # per-sample average (assumes the criterion returns a batch mean).\n",
        "                running_loss += loss.item() * inputs.size(0)\n",
        "                # Accumulate as a plain int so the accuracy below is a Python float.\n",
        "                running_corrects += torch.sum(preds == labels).item()\n",
        "            if is_training:\n",
        "                scheduler.step()\n",
        "\n",
        "            epoch_loss = running_loss / num_samples\n",
        "            epoch_acc = running_corrects / num_samples\n",
        "\n",
        "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
        "                phase, epoch_loss, epoch_acc))\n",
        "\n",
        "            # Snapshot the weights whenever validation accuracy improves.\n",
        "            if phase == 'valid' and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        print()\n",
        "\n",
        "    time_elapsed = time.time() - since\n",
        "    print('Training complete in {:.0f}m {:.0f}s'.format(\n",
        "        time_elapsed // 60, time_elapsed % 60))\n",
        "    # '{:.4f}' (precision 4) matches the per-epoch report; the former '{:4f}'\n",
        "    # only set a minimum field width.\n",
        "    print('Best val Acc: {:.4f}'.format(best_acc))\n",
        "\n",
        "    # load best model weights\n",
        "    model.load_state_dict(best_model_wts)\n",
        "    return model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rU5TcTsBPYvv"
      },
      "source": [
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda\") \n",
        "else:\n",
        "    device = torch.device(\"cpu\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TcpTifrVPgGE"
      },
      "source": [
        "# Class names as exposed by the training dataset.\n",
        "class_names = image_datasets['train'].classes\n",
        "# Number of samples in each split that the training loop iterates over.\n",
        "dataset_sizes = dict((split, len(image_datasets[split])) for split in ('train', 'valid'))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Kht4S1JCPmhl"
      },
      "source": [
        "# Show the class names and the split sizes, one per line.\n",
        "print(class_names, dataset_sizes, sep='\\n')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GUETQ0wIP1sh"
      },
      "source": [
        "# Move the model to the device selected earlier. Being the last expression of\n",
        "# the cell, the value returned by .to() is also shown as the cell output.\n",
        "pretrained_model.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_csze4BinO4R"
      },
      "source": [
        "Run the training for 5 epochs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yHOxrotFPsEO"
      },
      "source": [
        "# Train for 5 epochs and keep the model returned by train_model.\n",
        "pretrained_model = train_model(\n",
        "    model=pretrained_model,\n",
        "    criterion=criterion,\n",
        "    optimizer=optimizer,\n",
        "    scheduler=exp_lr_scheduler,\n",
        "    num_epochs=5,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pBsnUFQqfxhA"
      },
      "source": [
        "# Persist only the state dict (not the whole module object) to disk.\n",
        "torch.save(obj=pretrained_model.state_dict(), f=\"./dog_breed_classification.pth\")\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_8BhEM09nUzE"
      },
      "source": [
        "Load the trained/saved model for sanity check and testing on the test dataset. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "US5Q642Yf0nZ"
      },
      "source": [
        "# Rebuild the architecture; its weights are replaced by the checkpoint below.\n",
        "pretrained_model = torch.hub.load('pytorch/vision', 'resnet50')\n",
        "# Size the new 133-output head from the freshly built model instead of a\n",
        "# variable left over from an earlier cell.\n",
        "pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 133)\n",
        "# map_location remaps the stored tensors onto the current device, so a\n",
        "# checkpoint written on a GPU also loads on a CPU-only runtime.\n",
        "pretrained_model.load_state_dict(\n",
        "    torch.load('./dog_breed_classification.pth', map_location=device))\n",
        "pretrained_model.eval()\n",
        "pretrained_model.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QJZnTzeQgEv3"
      },
      "source": [
        "def model_test(model):\n",
        "    \"\"\"Print the correct / total counts and the accuracy of ``model`` on the test split.\n",
        "\n",
        "    Batches come from the module-level ``dataloaders['test']`` and are moved\n",
        "    to the module-level ``device``. The model is put in evaluation mode for\n",
        "    the duration of the pass and its previous mode is restored afterwards.\n",
        "\n",
        "    Args:\n",
        "        model: network to evaluate; expected to already live on ``device``.\n",
        "    \"\"\"\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    was_training = model.training\n",
        "    # Do not rely on the caller having switched the model to evaluation mode.\n",
        "    model.eval()\n",
        "    try:\n",
        "        with torch.no_grad():\n",
        "            for images, labels in dataloaders['test']:\n",
        "                images, labels = images.to(device), labels.to(device)\n",
        "                outputs = model(images)\n",
        "                # Index of the largest output along dim 1 is the predicted class.\n",
        "                _, predicted = torch.max(outputs, 1)\n",
        "                total += labels.size(0)\n",
        "                correct += (predicted == labels).sum().item()\n",
        "    finally:\n",
        "        model.train(was_training)\n",
        "    print('correct: {:d}  total: {:d}'.format(correct, total))\n",
        "    print('accuracy = {:f}'.format(correct / total))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tj6a3qOtgdIv"
      },
      "source": [
        "# Report the accuracy of the reloaded model on the test split.\n",
        "model_test(pretrained_model)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rACobIG9lv96"
      },
      "source": [
        "Running some prediction tests. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-skO-TvNgyJO"
      },
      "source": [
        "def prediction(model, filename):\n",
        "    \"\"\"Print the class name that ``model`` predicts for one image file.\n",
        "\n",
        "    The image is preprocessed with ``data_transforms['test']`` and the name\n",
        "    is looked up in the module-level ``class_names``. The model is put in\n",
        "    evaluation mode for the forward pass and its previous mode is restored.\n",
        "\n",
        "    Args:\n",
        "        model: network to query; expected to already live on ``device``.\n",
        "        filename: path of the image file to classify.\n",
        "    \"\"\"\n",
        "    # Force three channels: the normalization step is configured with three\n",
        "    # channel statistics, so grayscale or RGBA files would otherwise fail.\n",
        "    with Image.open(filename) as img:\n",
        "        batch = data_transforms['test'](img.convert('RGB')).unsqueeze(0)\n",
        "    was_training = model.training\n",
        "    model.eval()\n",
        "    try:\n",
        "        # Inference only, so no autograd history is needed.\n",
        "        with torch.no_grad():\n",
        "            logits = model(batch.to(device))\n",
        "    finally:\n",
        "        model.train(was_training)\n",
        "    # .item() turns the 0-dim index tensor into a plain int for list indexing.\n",
        "    print(class_names[logits.argmax().item()])\n",
        "    \n",
        "# Spot-check two validation images taken from the same class folder.\n",
        "for sample_path in (\n",
        "    '/content/dogImages/valid/005.Alaskan_malamute/Alaskan_malamute_00298.jpg',\n",
        "    '/content/dogImages/valid/005.Alaskan_malamute/Alaskan_malamute_00344.jpg',\n",
        "):\n",
        "    prediction(pretrained_model, sample_path)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VZbxP1UUndmH"
      },
      "source": [
        "Downloading the trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gAjFbcez7PXb"
      },
      "source": [
        "# google.colab is presumably only importable inside a Colab runtime, which is\n",
        "# why this import sits next to its single use -- TODO confirm.\n",
        "from google.colab import files\n",
        "# Hand the saved weights file to the browser as a download.\n",
        "files.download('./dog_breed_classification.pth') "
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}